from abc import ABC, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.utils import masked_mean
from trinity.utils.registry import Registry

# Registry of entropy loss function classes. Implementations in this module
# add themselves under a string key via `ENTROPY_LOSS_FN.register_module(<key>)`.
ENTROPY_LOSS_FN = Registry("entropy_loss_fn")


class EntropyLossFn(ABC):
    """
    Abstract interface shared by all entropy loss functions.

    Concrete subclasses must implement `__call__`; `default_args` supplies
    the default constructor arguments.
    """

    @abstractmethod
    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the entropy loss.

        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask.
            **kwargs: Extra keyword arguments; each implementation decides
                which of them (if any) it uses.

        Returns:
            `torch.Tensor`: The calculated entropy loss.
            `Dict`: The metrics for logging.
        """

    @classmethod
    def default_args(cls) -> Dict:
        """
        Build the default arguments for the entropy loss function.

        Returns:
            `Dict`: A new dictionary mapping `entropy_coef` to `0.0`.
        """
        defaults = dict(entropy_coef=0.0)
        return defaults


@ENTROPY_LOSS_FN.register_module("default")
class DefaultEntropyLossFn(EntropyLossFn):
    """
    Basic entropy loss: the masked mean of the entropy, scaled by a coefficient.
    """

    def __init__(self, entropy_coef: float):
        """
        Args:
            entropy_coef (`float`): Factor multiplied onto the masked-mean entropy.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask handed to `masked_mean`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            `torch.Tensor`: The masked-mean entropy multiplied by `entropy_coef`.
            `Dict`: `{"entropy_loss": <unscaled masked-mean entropy as a float>}`.
        """
        # `masked_mean` is imported from trinity.algorithm.utils; presumably it
        # averages `entropy` over the positions selected by `action_mask`.
        mean_entropy = masked_mean(entropy, action_mask)
        weighted_loss = mean_entropy * self.entropy_coef
        # The logged metric is the unscaled mean, detached from the graph.
        metrics = {"entropy_loss": mean_entropy.detach().item()}
        return weighted_loss, metrics


@ENTROPY_LOSS_FN.register_module("mix")
class MixEntropyLossFn(EntropyLossFn):
    """
    Basic entropy loss function for mix algorithm.

    Rows flagged by `expert_mask` are dropped, and the entropy loss is computed
    over the remaining rows only.
    """

    def __init__(self, entropy_coef: float):
        """
        Args:
            entropy_coef (`float`): Factor multiplied onto the masked-mean entropy.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        expert_mask: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask; indexed along its
                first dimension with the same row selection as `entropy`.
            expert_mask (`torch.Tensor`): Per-row flags along the first
                dimension of `entropy`; truthy rows are excluded. Required
                despite the `None` default.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            `torch.Tensor`: The masked-mean entropy of the non-excluded rows,
                multiplied by `entropy_coef`.
            `Dict`: `{"entropy_loss": <unscaled masked-mean entropy as a float>}`.

        Raises:
            ValueError: If `expert_mask` is missing, or its length differs from
                the first dimension of `entropy`.
        """
        if expert_mask is None:
            raise ValueError("expert_mask is required for MixEntropyLossFn")
        # Raise explicitly instead of asserting: asserts are stripped under
        # `python -O`, which would let a mismatched mask through unchecked.
        if len(expert_mask) != entropy.shape[0]:
            raise ValueError(f"Error: {len(expert_mask)=} != {entropy.shape[0]=}")
        # Coerce to bool before inverting. On an integer 0/1 mask `~` is a
        # bitwise NOT (giving -1/-2), which would then be used as integer
        # indices and silently select the wrong rows. A bool tensor passes
        # through `as_tensor` unchanged.
        keep_rows = ~torch.as_tensor(expert_mask, dtype=torch.bool)
        entropy = entropy[keep_rows]
        action_mask = action_mask[keep_rows]
        entropy_loss = masked_mean(entropy, action_mask)
        return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


@ENTROPY_LOSS_FN.register_module("none")
class DummyEntropyLossFn(EntropyLossFn):
    """
    Dummy entropy loss function: always yields a zero loss and no metrics.
    """

    def __init__(self, entropy_coef: float):
        """
        Args:
            entropy_coef (`float`): Stored on the instance; `__call__` does not read it.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            entropy (`torch.Tensor`): Ignored.
            action_mask (`torch.Tensor`): Ignored.
            **kwargs: Ignored.

        Returns:
            `torch.Tensor`: A newly created scalar tensor holding `0.0`.
            `Dict`: An empty metrics dictionary.
        """
        zero_loss = torch.tensor(0.0)
        no_metrics: Dict = {}
        return zero_loss, no_metrics
